// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/base/network_interfaces_win.h"

#include <algorithm>
#include <cstring>
#include <cwchar>
#include <memory>

#include "base/files/file_path.h"
#include "base/lazy_instance.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_util.h"
#include "base/strings/sys_string_conversions.h"
#include "base/strings/utf_string_conversions.h"
#include "base/threading/thread_restrictions.h"
#include "base/win/scoped_handle.h"
#include "net/base/escape.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "url/gurl.h"

namespace net {

namespace {

    // Maps a Windows IF_TYPE_* interface type to a
    // NetworkChangeNotifier::ConnectionType. Only wired Ethernet
    // (IF_TYPE_ETHERNET_CSMACD) and 802.11 wireless (IF_TYPE_IEEE80211) are
    // recognized; every other type is reported as CONNECTION_UNKNOWN.
    NetworkChangeNotifier::ConnectionType GetNetworkInterfaceType(DWORD ifType)
    {
        switch (ifType) {
        case IF_TYPE_ETHERNET_CSMACD:
            return NetworkChangeNotifier::CONNECTION_ETHERNET;
        case IF_TYPE_IEEE80211:
            return NetworkChangeNotifier::CONNECTION_WIFI;
        default:
            // TODO(mallinath) - Cellular?
            return NetworkChangeNotifier::CONNECTION_UNKNOWN;
        }
    }

    // Returns scoped_ptr to WLAN_CONNECTION_ATTRIBUTES. The scoped_ptr may hold a
    // NULL pointer if WLAN_CONNECTION_ATTRIBUTES is unavailable.
    std::unique_ptr<WLAN_CONNECTION_ATTRIBUTES, internal::WlanApiDeleter>
    GetConnectionAttributes()
    {
        const internal::WlanApi& wlanapi = internal::WlanApi::GetInstance();
        std::unique_ptr<WLAN_CONNECTION_ATTRIBUTES, internal::WlanApiDeleter>
            wlan_connection_attributes;
        if (!wlanapi.initialized)
            return wlan_connection_attributes;

        internal::WlanHandle client;
        DWORD cur_version = 0;
        const DWORD kMaxClientVersion = 2;
        DWORD result = wlanapi.OpenHandle(kMaxClientVersion, &cur_version, &client);
        if (result != ERROR_SUCCESS)
            return wlan_connection_attributes;

        WLAN_INTERFACE_INFO_LIST* interface_list_ptr = NULL;
        result = wlanapi.enum_interfaces_func(client.Get(), NULL, &interface_list_ptr);
        if (result != ERROR_SUCCESS)
            return wlan_connection_attributes;
        std::unique_ptr<WLAN_INTERFACE_INFO_LIST, internal::WlanApiDeleter>
            interface_list(interface_list_ptr);

        // Assume at most one connected wifi interface.
        WLAN_INTERFACE_INFO* info = NULL;
        for (unsigned i = 0; i < interface_list->dwNumberOfItems; ++i) {
            if (interface_list->InterfaceInfo[i].isState == wlan_interface_state_connected) {
                info = &interface_list->InterfaceInfo[i];
                break;
            }
        }

        if (info == NULL)
            return wlan_connection_attributes;

        WLAN_CONNECTION_ATTRIBUTES* conn_info_ptr = nullptr;
        DWORD conn_info_size = 0;
        WLAN_OPCODE_VALUE_TYPE op_code;
        result = wlanapi.query_interface_func(
            client.Get(), &info->InterfaceGuid, wlan_intf_opcode_current_connection,
            NULL, &conn_info_size, reinterpret_cast<VOID**>(&conn_info_ptr),
            &op_code);
        wlan_connection_attributes.reset(conn_info_ptr);
        if (result == ERROR_SUCCESS)
            DCHECK(conn_info_ptr);
        else
            wlan_connection_attributes.reset();
        return wlan_connection_attributes;
    }

} // namespace

namespace internal {

    // Lazily constructed, process-wide WlanApi instance. The ::Leaky trait means
    // it is intentionally never destroyed at shutdown.
    base::LazyInstance<WlanApi>::Leaky lazy_wlanapi = LAZY_INSTANCE_INITIALIZER;

    // Returns the shared WlanApi. The first call constructs it, which attempts to
    // load wlanapi.dll and resolve its entry points.
    WlanApi& WlanApi::GetInstance()
    {
        WlanApi& instance = lazy_wlanapi.Get();
        return instance;
    }

    // Loads wlanapi.dll from the system directory and resolves the WLAN entry
    // points used by this file. |initialized| becomes true only if the DLL loaded
    // and every entry point was found; otherwise callers must not use any of the
    // function pointers.
    WlanApi::WlanApi()
        : initialized(false)
    {
        // Keep |module| well-defined on every early return below.
        module = NULL;

        // Use an absolute path to load the DLL to avoid DLL preloading attacks.
        static const wchar_t* const kDLL = L"%WINDIR%\\system32\\wlanapi.dll";
        wchar_t path[MAX_PATH] = { 0 };
        // ExpandEnvironmentStrings() returns 0 on failure, and the required size
        // (which exceeds the buffer) when the result does not fit. In both cases
        // |path| does not hold the absolute DLL path, and passing it to
        // LoadLibraryEx() would defeat the point of using an absolute path.
        const DWORD expanded = ExpandEnvironmentStrings(kDLL, path, arraysize(path));
        if (expanded == 0 || expanded > arraysize(path))
            return;

        module = ::LoadLibraryEx(path, NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
        if (!module)
            return;

        open_handle_func = reinterpret_cast<WlanOpenHandleFunc>(
            ::GetProcAddress(module, "WlanOpenHandle"));
        enum_interfaces_func = reinterpret_cast<WlanEnumInterfacesFunc>(
            ::GetProcAddress(module, "WlanEnumInterfaces"));
        query_interface_func = reinterpret_cast<WlanQueryInterfaceFunc>(
            ::GetProcAddress(module, "WlanQueryInterface"));
        set_interface_func = reinterpret_cast<WlanSetInterfaceFunc>(
            ::GetProcAddress(module, "WlanSetInterface"));
        free_memory_func = reinterpret_cast<WlanFreeMemoryFunc>(
            ::GetProcAddress(module, "WlanFreeMemory"));
        close_handle_func = reinterpret_cast<WlanCloseHandleFunc>(
            ::GetProcAddress(module, "WlanCloseHandle"));
        initialized = open_handle_func && enum_interfaces_func &&
                      query_interface_func && set_interface_func &&
                      free_memory_func && close_handle_func;
    }

    // Appends to |networks| one NetworkInterface per usable unicast address in
    // the |adapters| linked list (as produced by GetAdaptersAddresses()).
    //
    // Skipped:
    //  - loopback adapters;
    //  - adapters whose OperStatus is not IfOperStatusUp;
    //  - host-side VMware adapters, when |policy| contains
    //    EXCLUDE_HOST_SCOPE_VIRTUAL_INTERFACES;
    //  - addresses that are not AF_INET/AF_INET6;
    //  - addresses that cannot be converted to an IPEndPoint;
    //  - addresses whose DAD state is not Preferred.
    //
    // Always returns true.
    bool GetNetworkListImpl(NetworkInterfaceList* networks,
        int policy,
        const IP_ADAPTER_ADDRESSES* adapters)
    {
        for (const IP_ADAPTER_ADDRESSES* adapter = adapters; adapter != NULL;
             adapter = adapter->Next) {
            // Ignore the loopback device.
            if (adapter->IfType == IF_TYPE_SOFTWARE_LOOPBACK)
                continue;

            // Ignore adapters that are not operationally up.
            if (adapter->OperStatus != IfOperStatusUp)
                continue;

            // Ignore any HOST side vmware adapters with a description like:
            //   VMware Virtual Ethernet Adapter for VMnet1
            // but don't ignore any GUEST side adapters with a description like:
            //   VMware Accelerated AMD PCNet Adapter #2
            // That text is in the wide-character Description field. Per the
            // IP_ADAPTER_ADDRESSES documentation, AdapterName is the adapter's
            // GUID-style name and would not contain it. AdapterName is still
            // checked to preserve the previous matching behavior.
            if (policy & EXCLUDE_HOST_SCOPE_VIRTUAL_INTERFACES) {
                const bool name_has_vmnet = adapter->AdapterName != NULL &&
                    strstr(adapter->AdapterName, "VMnet") != NULL;
                const bool description_has_vmnet = adapter->Description != NULL &&
                    wcsstr(adapter->Description, L"VMnet") != NULL;
                if (name_has_vmnet || description_has_vmnet)
                    continue;
            }

            for (IP_ADAPTER_UNICAST_ADDRESS* address = adapter->FirstUnicastAddress;
                 address != NULL; address = address->Next) {
                // Defensive: an entry without a sockaddr cannot be converted.
                if (address->Address.lpSockaddr == NULL)
                    continue;

                const int family = address->Address.lpSockaddr->sa_family;
                if (family != AF_INET && family != AF_INET6)
                    continue;

                IPEndPoint endpoint;
                if (!endpoint.FromSockAddr(address->Address.lpSockaddr,
                        address->Address.iSockaddrLength)) {
                    continue;
                }

                // If the duplicate address detection (DAD) state is not changed to
                // Preferred, skip this address.
                if (address->DadState != IpDadStatePreferred)
                    continue;

                const size_t prefix_length = address->OnLinkPrefixLength;
                // IPv4 and IPv6 addresses report different interface indices.
                const uint32_t index =
                    (family == AF_INET) ? adapter->IfIndex : adapter->Ipv6IfIndex;

                // From http://technet.microsoft.com/en-us/ff568768(v=vs.60).aspx, the
                // way to identify a temporary IPv6 Address is to check if
                // PrefixOrigin is equal to IpPrefixOriginRouterAdvertisement and
                // SuffixOrigin equal to IpSuffixOriginRandom.
                int ip_address_attributes = IP_ADDRESS_ATTRIBUTE_NONE;
                if (family == AF_INET6) {
                    if (address->PrefixOrigin == IpPrefixOriginRouterAdvertisement &&
                        address->SuffixOrigin == IpSuffixOriginRandom) {
                        ip_address_attributes |= IP_ADDRESS_ATTRIBUTE_TEMPORARY;
                    }
                    // A preferred lifetime of zero is reported as deprecated.
                    if (address->PreferredLifetime == 0)
                        ip_address_attributes |= IP_ADDRESS_ATTRIBUTE_DEPRECATED;
                }

                networks->push_back(NetworkInterface(
                    adapter->AdapterName,
                    base::SysWideToNativeMB(adapter->FriendlyName), index,
                    GetNetworkInterfaceType(adapter->IfType), endpoint.address(),
                    prefix_length, ip_address_attributes));
            }
        }
        return true;
    }

} // namespace internal

// Fills |networks| with the machine's network interfaces, filtered according
// to |policy| (see internal::GetNetworkListImpl()).
//
// Returns true on success, including the case where GetAdaptersAddresses()
// reports ERROR_NO_DATA (no adapters; nothing is appended). Returns false and
// logs if GetAdaptersAddresses() fails or still needs a larger buffer after
// MAX_GETADAPTERSADDRESSES_TRIES attempts.
bool GetNetworkList(NetworkInterfaceList* networks, int policy)
{
    // Max number of times to retry GetAdaptersAddresses due to
    // ERROR_BUFFER_OVERFLOW. If GetAdaptersAddresses returns this indefinitely
    // due to an unforeseen reason, we don't want to be stuck in an endless loop.
    static constexpr int MAX_GETADAPTERSADDRESSES_TRIES = 10;
    // Use an initial buffer size of 15KB, as recommended by MSDN. See:
    // https://msdn.microsoft.com/en-us/library/windows/desktop/aa365915(v=vs.85).aspx
    static constexpr int INITIAL_BUFFER_SIZE = 15000;

    ULONG len = INITIAL_BUFFER_SIZE;
    ULONG flags = 0;
    // Initial buffer allocated on the stack. It is reinterpreted as an
    // IP_ADAPTER_ADDRESSES list below, so it must carry that structure's
    // alignment; a plain char array is only guaranteed 1-byte alignment.
    alignas(IP_ADAPTER_ADDRESSES) char initial_buf[INITIAL_BUFFER_SIZE];
    // Dynamic buffer in case the initial buffer isn't large enough. Storage from
    // new char[] is suitably aligned for any fundamental-alignment object.
    std::unique_ptr<char[]> buf;

    // GetAdaptersAddresses() may require IO operations.
    base::ThreadRestrictions::AssertIOAllowed();

    IP_ADAPTER_ADDRESSES* adapters = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(initial_buf);
    ULONG result = GetAdaptersAddresses(AF_UNSPEC, flags, nullptr, adapters, &len);

    // If we get ERROR_BUFFER_OVERFLOW, call GetAdaptersAddresses in a loop,
    // because the required size may increase between successive calls, resulting
    // in ERROR_BUFFER_OVERFLOW multiple times.
    for (int tries = 1; result == ERROR_BUFFER_OVERFLOW && tries < MAX_GETADAPTERSADDRESSES_TRIES;
         ++tries) {
        buf.reset(new char[len]);
        adapters = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(buf.get());
        result = GetAdaptersAddresses(AF_UNSPEC, flags, nullptr, adapters, &len);
    }

    if (result == ERROR_NO_DATA) {
        // There are 0 networks.
        return true;
    } else if (result != NO_ERROR) {
        LOG(ERROR) << "GetAdaptersAddresses failed: " << result;
        return false;
    }

    return internal::GetNetworkListImpl(networks, policy, adapters);
}

// Reports the 802.11 PHY family of the currently connected WLAN interface.
// Returns WIFI_PHY_LAYER_PROTOCOL_NONE when no connection attributes can be
// obtained, and WIFI_PHY_LAYER_PROTOCOL_UNKNOWN for PHY types not listed below.
WifiPHYLayerProtocol GetWifiPHYLayerProtocol()
{
    const auto attributes = GetConnectionAttributes();
    if (!attributes)
        return WIFI_PHY_LAYER_PROTOCOL_NONE;

    switch (attributes->wlanAssociationAttributes.dot11PhyType) {
    case dot11_phy_type_fhss:
    case dot11_phy_type_irbaseband:
        return WIFI_PHY_LAYER_PROTOCOL_ANCIENT;
    case dot11_phy_type_dsss:
    case dot11_phy_type_hrdsss:
        return WIFI_PHY_LAYER_PROTOCOL_B;
    case dot11_phy_type_ofdm:
        return WIFI_PHY_LAYER_PROTOCOL_A;
    case dot11_phy_type_erp:
        return WIFI_PHY_LAYER_PROTOCOL_G;
    case dot11_phy_type_ht:
        return WIFI_PHY_LAYER_PROTOCOL_N;
    default:
        return WIFI_PHY_LAYER_PROTOCOL_UNKNOWN;
    }
}

// Applies WIFI_OPTIONS_* settings to every WLAN interface through a WLAN client
// handle owned by this object.
// Note: There is no need to explicitly set the options back, as the OS will
// automatically set them back when the WlanHandle is closed.
class WifiOptionSetter : public ScopedWifiOptions {
public:
    // |options| is a bitmask of WIFI_OPTIONS_* values. This is best effort: if
    // the WLAN API is unavailable, or the handle cannot be opened, or the
    // interfaces cannot be enumerated, nothing is applied. Failures of
    // individual set_interface_func() calls are ignored.
    explicit WifiOptionSetter(int options)
    {
        const internal::WlanApi& wlanapi = internal::WlanApi::GetInstance();
        if (!wlanapi.initialized)
            return;

        DWORD cur_version = 0;
        const DWORD kMaxClientVersion = 2;
        DWORD result = wlanapi.OpenHandle(
            kMaxClientVersion, &cur_version, &client_);
        if (result != ERROR_SUCCESS)
            return;

        WLAN_INTERFACE_INFO_LIST* interface_list_ptr = NULL;
        result = wlanapi.enum_interfaces_func(client_.Get(), NULL,
            &interface_list_ptr);
        if (result != ERROR_SUCCESS)
            return;
        std::unique_ptr<WLAN_INTERFACE_INFO_LIST, internal::WlanApiDeleter>
            interface_list(interface_list_ptr);

        for (unsigned i = 0; i < interface_list->dwNumberOfItems; ++i) {
            WLAN_INTERFACE_INFO* info = &interface_list->InterfaceInfo[i];
            if (options & WIFI_OPTIONS_DISABLE_SCAN) {
                // Turn background scanning off on this interface.
                BOOL data = FALSE;
                wlanapi.set_interface_func(client_.Get(),
                    &info->InterfaceGuid,
                    wlan_intf_opcode_background_scan_enabled,
                    sizeof(data),
                    &data,
                    NULL);
            }
            if (options & WIFI_OPTIONS_MEDIA_STREAMING_MODE) {
                // Turn media streaming mode on for this interface.
                BOOL data = TRUE;
                wlanapi.set_interface_func(client_.Get(),
                    &info->InterfaceGuid,
                    wlan_intf_opcode_media_streaming_mode,
                    sizeof(data),
                    &data,
                    NULL);
            }
        }
    }

private:
    // Client handle whose lifetime scopes the applied options.
    internal::WlanHandle client_;
};

// Creates a WifiOptionSetter for the WIFI_OPTIONS_* bitmask |options| and
// hands ownership to the caller. Per the note on WifiOptionSetter, the options
// revert when the setter's WLAN handle is closed.
std::unique_ptr<ScopedWifiOptions> SetWifiOptions(int options)
{
    std::unique_ptr<ScopedWifiOptions> scoped_options(new WifiOptionSetter(options));
    return scoped_options;
}

// Returns the SSID of the currently connected WLAN interface as raw bytes, or
// an empty string when no connection attributes are available.
std::string GetWifiSSID()
{
    auto conn_info = GetConnectionAttributes();

    if (!conn_info.get())
        return "";

    // Bind by reference; conn_info owns the storage until this function returns.
    const DOT11_SSID& dot11_ssid = conn_info->wlanAssociationAttributes.dot11Ssid;
    // uSSIDLength is reported by the OS. Clamp it to the size of the fixed
    // ucSSID array so an out-of-range length can never cause an
    // out-of-bounds read.
    const size_t ssid_length =
        std::min<size_t>(dot11_ssid.uSSIDLength, sizeof(dot11_ssid.ucSSID));
    return std::string(reinterpret_cast<const char*>(dot11_ssid.ucSSID),
        ssid_length);
}

} // namespace net
